{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MhoQ0WE77laV"
   },
   "source": [
    "##### Copyright 2018 The TensorFlow Authors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2023-11-08T00:31:57.582337Z",
     "iopub.status.busy": "2023-11-08T00:31:57.582061Z",
     "iopub.status.idle": "2023-11-08T00:31:57.586462Z",
     "shell.execute_reply": "2023-11-08T00:31:57.585790Z"
    },
    "id": "_ckMIh7O7s6D"
   },
   "outputs": [],
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {
     "iopub.execute_input": "2023-11-08T00:31:57.589670Z",
     "iopub.status.busy": "2023-11-08T00:31:57.589429Z",
     "iopub.status.idle": "2023-11-08T00:31:57.593266Z",
     "shell.execute_reply": "2023-11-08T00:31:57.592623Z"
    },
    "id": "vasWnqRgy1H4"
   },
   "outputs": [],
   "source": [
    "#@title MIT License\n",
    "#\n",
    "# Copyright (c) 2017 François Chollet\n",
    "#\n",
    "# Permission is hereby granted, free of charge, to any person obtaining a\n",
    "# copy of this software and associated documentation files (the \"Software\"),\n",
    "# to deal in the Software without restriction, including without limitation\n",
    "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n",
    "# and/or sell copies of the Software, and to permit persons to whom the\n",
    "# Software is furnished to do so, subject to the following conditions:\n",
    "#\n",
    "# The above copyright notice and this permission notice shall be included in\n",
    "# all copies or substantial portions of the Software.\n",
    "#\n",
    "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
    "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
    "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n",
    "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
    "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n",
    "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n",
    "# DEALINGS IN THE SOFTWARE."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jYysdyb-CaWM"
   },
   "source": [
    "# Basic classification: Classify images of clothing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FbVhjPpzn6BM"
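   },
   "source": [
    "This guide trains a neural network model to classify images of clothing, such as sneakers and shirts. It's okay if you don't understand all the details; this is a fast-paced overview of a complete TensorFlow program, with the details explained as you go.\n",
    "\n",
    "This guide uses [tf.keras](https://tensorflow.google.cn/guide/keras), a high-level API to build and train models in TensorFlow."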
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:31:57.596699Z",
     "iopub.status.busy": "2023-11-08T00:31:57.596464Z",
     "iopub.status.idle": "2023-11-08T00:32:00.493416Z",
     "shell.execute_reply": "2023-11-08T00:32:00.492593Z"
    },
    "id": "dzLKpmZICaWN"
   },
   "outputs": [],
   "source": [
    "# TensorFlow and tf.keras\n",
    "import tensorflow as tf\n",
    "\n",
    "# Helper libraries\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yR0EdgrLCaWR"
   },
   "source": [
    "## Import the Fashion MNIST dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DLdCchMdCaWQ"
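   },
   "source": [
    "This guide uses the [Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset, which contains 70,000 grayscale images in 10 categories. The images show individual articles of clothing at low resolution (28 by 28 pixels), as seen here:\n",
    "\n",
    "<table>\n",
    "  <tr><td>     <img alt=\"Fashion MNIST sprite\" src=\"https://tensorflow.google.cn/images/fashion-mnist-sprite.png\"> </td></tr>\n",
    "  <tr><td align=\"center\">     <b>Figure 1.</b>  <a href=\"https://github.com/zalandoresearch/fashion-mnist\">Fashion-MNIST samples</a> (by Zalando, MIT License).<br> </td></tr>\n",
    "</table>\n",
    "\n",
    "Fashion MNIST is intended as a drop-in replacement for the classic [MNIST](http://yann.lecun.com/exdb/mnist/) dataset, which is often used as the \"Hello, World\" of machine learning programs for computer vision. The MNIST dataset contains images of handwritten digits (0, 1, 2, etc.) in a format identical to that of the articles of clothing you'll use here.\n",
    "\n",
    "This guide uses Fashion MNIST for variety, and because it is a slightly more challenging problem than regular MNIST. Both datasets are relatively small and are used to verify that an algorithm works as expected. They're good starting points to test and debug code.\n",
    "\n",
    "Here, 60,000 images are used to train the network and 10,000 images to evaluate how accurately the network learned to classify images. You can import and load the Fashion MNIST data directly from TensorFlow:"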
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:00.497855Z",
     "iopub.status.busy": "2023-11-08T00:32:00.496948Z",
     "iopub.status.idle": "2023-11-08T00:32:00.961685Z",
     "shell.execute_reply": "2023-11-08T00:32:00.960688Z"
    },
    "id": "7MqDQO0KCaWS"
   },
   "outputs": [],
   "source": [
    "fashion_mnist = tf.keras.datasets.fashion_mnist\n",
    "\n",
    "(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "t9FDsUlxCaWW"
   },
   "source": [
    "Loading the dataset returns four NumPy arrays:\n",
    "\n",
    "- The `train_images` and `train_labels` arrays are the *training set*, the data the model uses to learn.\n",
    "- The `test_images` and `test_labels` arrays are the *test set*, which the model is tested against.\n",
    "\n",
    "The images are 28x28 NumPy arrays, with pixel values ranging from 0 to 255. The *labels* are an array of integers, ranging from 0 to 9. These correspond to the *class* of clothing the image represents:\n",
    "\n",
    "<table>\n",
    "  <tr>\n",
    "    <th>Label</th>\n",
    "    <th>Class</th>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>0</td>\n",
    "    <td>T-shirt/top</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>1</td>\n",
    "    <td>Trouser</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>2</td>\n",
    "    <td>Pullover</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>3</td>\n",
    "    <td>Dress</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>4</td>\n",
    "    <td>Coat</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>5</td>\n",
    "    <td>Sandal</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>6</td>\n",
    "    <td>Shirt</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>7</td>\n",
    "    <td>Sneaker</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>8</td>\n",
    "    <td>Bag</td>\n",
    "  </tr>\n",
    "  <tr>\n",
    "    <td>9</td>\n",
    "    <td>Ankle boot</td>\n",
    "  </tr>\n",
    "</table>\n",
    "\n",
    "Each image is mapped to a single label. Since the *class names* are not included with the dataset, store them here to use later when plotting the images:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:00.966946Z",
     "iopub.status.busy": "2023-11-08T00:32:00.966224Z",
     "iopub.status.idle": "2023-11-08T00:32:00.970181Z",
     "shell.execute_reply": "2023-11-08T00:32:00.969466Z"
    },
    "id": "IjnLH5S2CaWx"
   },
   "outputs": [],
   "source": [
    "class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n",
    "               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']"
   ]
  },
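  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small illustration of how the integer labels relate to the class names stored above, the following sketch looks up the names for a handful of labels with NumPy indexing. The `example_labels` array is made up for illustration and is not taken from the dataset.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n",
    "               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']\n",
    "\n",
    "# Hypothetical labels, standing in for a slice of train_labels.\n",
    "example_labels = np.array([9, 0, 0, 3, 7], dtype=np.uint8)\n",
    "\n",
    "# Indexing an array of names with an array of labels maps every label at once.\n",
    "example_names = np.array(class_names)[example_labels]\n",
    "print(example_names.tolist())\n",
    "```\n",
    "\n",
    "The same lookup works for a single label, for example `class_names[example_labels[0]]`."
   ]
  },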
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Brm0b_KACaWX"
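   },
   "source": [
    "## Explore the data\n",
    "\n",
    "Let's explore the format of the dataset before training the model. The following shows that there are 60,000 images in the training set, with each image represented as 28 x 28 pixels:"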
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:00.973607Z",
     "iopub.status.busy": "2023-11-08T00:32:00.973340Z",
     "iopub.status.idle": "2023-11-08T00:32:00.980293Z",
     "shell.execute_reply": "2023-11-08T00:32:00.979620Z"
    },
    "id": "zW5k_xz1CaWX"
   },
   "outputs": [],
   "source": [
    "train_images.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cIAcvQqMCaWf"
   },
   "source": [
    "Likewise, there are 60,000 labels in the training set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:00.983756Z",
     "iopub.status.busy": "2023-11-08T00:32:00.983484Z",
     "iopub.status.idle": "2023-11-08T00:32:00.987678Z",
     "shell.execute_reply": "2023-11-08T00:32:00.987091Z"
    },
    "id": "TRFYHB2mCaWb"
   },
   "outputs": [],
   "source": [
    "len(train_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YSlYxFuRCaWk"
   },
   "source": [
    "Each label is an integer between 0 and 9:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:00.990819Z",
     "iopub.status.busy": "2023-11-08T00:32:00.990571Z",
     "iopub.status.idle": "2023-11-08T00:32:00.995193Z",
     "shell.execute_reply": "2023-11-08T00:32:00.994575Z"
    },
    "id": "XKnCTHz4CaWg"
   },
   "outputs": [],
   "source": [
    "train_labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TMPI88iZpO2T"
   },
   "source": [
    "There are 10,000 images in the test set. Again, each image is represented as 28 x 28 pixels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:00.998525Z",
     "iopub.status.busy": "2023-11-08T00:32:00.998271Z",
     "iopub.status.idle": "2023-11-08T00:32:01.002505Z",
     "shell.execute_reply": "2023-11-08T00:32:01.001852Z"
    },
    "id": "2KFnYlcwCaWl"
   },
   "outputs": [],
   "source": [
    "test_images.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rd0A0Iu0CaWq"
   },
   "source": [
    "And the test set contains 10,000 image labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:01.005749Z",
     "iopub.status.busy": "2023-11-08T00:32:01.005514Z",
     "iopub.status.idle": "2023-11-08T00:32:01.009770Z",
     "shell.execute_reply": "2023-11-08T00:32:01.009189Z"
    },
    "id": "iJmPr5-ACaWn"
   },
   "outputs": [],
   "source": [
    "len(test_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ES6uQoLKCaWr"
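   },
   "source": [
    "## Preprocess the data\n",
    "\n",
    "The data must be preprocessed before training the network. If you inspect the first image in the training set, you will see that the pixel values fall in the range of 0 to 255:"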
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:01.013242Z",
     "iopub.status.busy": "2023-11-08T00:32:01.012605Z",
     "iopub.status.idle": "2023-11-08T00:32:01.206566Z",
     "shell.execute_reply": "2023-11-08T00:32:01.205852Z"
    },
    "id": "m4VEw8Ud9Quh"
   },
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "plt.imshow(train_images[0])\n",
    "plt.colorbar()\n",
    "plt.grid(False)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Wz7l27Lz9S1P"
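   },
   "source": [
    "Scale these values to a range of 0 to 1 before feeding them to the neural network model. To do so, divide the values by 255. It's important that the *training set* and the *test set* be preprocessed in the same way:"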
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:01.210838Z",
     "iopub.status.busy": "2023-11-08T00:32:01.210205Z",
     "iopub.status.idle": "2023-11-08T00:32:01.405851Z",
     "shell.execute_reply": "2023-11-08T00:32:01.404750Z"
    },
    "id": "bW5WzIPlCaWv"
   },
   "outputs": [],
   "source": [
    "train_images = train_images / 255.0\n",
    "\n",
    "test_images = test_images / 255.0"
   ]
  },
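  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect of the division can be checked on a tiny array. This sketch uses a made-up 2 x 2 image (`fake_image` is a hypothetical stand-in for one entry of `train_images`, assumed here to hold `uint8` values) and shows that dividing by `255.0` produces floating-point values between 0 and 1.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical 2 x 2 grayscale image with integer pixel values in [0, 255].\n",
    "fake_image = np.array([[0, 51], [204, 255]], dtype=np.uint8)\n",
    "\n",
    "# Dividing by a float promotes the integers to float64 and rescales them.\n",
    "scaled = fake_image / 255.0\n",
    "\n",
    "print(scaled.dtype)\n",
    "print(scaled.min(), scaled.max())\n",
    "```\n",
    "\n",
    "Applying the same expression to both the training and the test images keeps the two sets on the same scale."
   ]
  },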
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ee638AlnCaWz"
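   },
   "source": [
    "To verify that the data is in the correct format and that you're ready to build and train the network, let's display the first 25 images from the *training set* and display the class name below each image."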
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:01.410740Z",
     "iopub.status.busy": "2023-11-08T00:32:01.409959Z",
     "iopub.status.idle": "2023-11-08T00:32:02.283088Z",
     "shell.execute_reply": "2023-11-08T00:32:02.282325Z"
    },
    "id": "oZTImqg_CaW1"
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,10))\n",
    "for i in range(25):\n",
    "    plt.subplot(5,5,i+1)\n",
    "    plt.xticks([])\n",
    "    plt.yticks([])\n",
    "    plt.grid(False)\n",
    "    plt.imshow(train_images[i], cmap=plt.cm.binary)\n",
    "    plt.xlabel(class_names[train_labels[i]])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "59veuiEZCaW4"
   },
   "source": [
    "## Build the model\n",
    "\n",
    "Building the neural network requires configuring the layers of the model, then compiling the model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Gxg1XGm0eOBy"
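   },
   "source": [
    "### Set up the layers\n",
    "\n",
    "The basic building block of a neural network is the <em>layer</em>. Layers extract representations from the data fed into them. Ideally, these representations are meaningful for the problem at hand.\n",
    "\n",
    "Most of deep learning consists of chaining together simple layers. Most layers, such as `tf.keras.layers.Dense`, have parameters that are learned during training."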
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:02.287662Z",
     "iopub.status.busy": "2023-11-08T00:32:02.287010Z",
     "iopub.status.idle": "2023-11-08T00:32:04.609753Z",
     "shell.execute_reply": "2023-11-08T00:32:04.608927Z"
    },
    "id": "9ODch-OFCaW4"
   },
   "outputs": [],
   "source": [
    "model = tf.keras.Sequential([\n",
    "    tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
    "    tf.keras.layers.Dense(128, activation='relu'),\n",
    "    tf.keras.layers.Dense(10)\n",
    "])"
   ]
  },
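  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the shapes concrete, here is a minimal NumPy sketch of the computation that the Flatten, Dense(128, relu), Dense(10) stack defined above performs on one image. The weights are random placeholders rather than trained values, and the names `w1`, `b1`, `w2`, `b2` and `forward` are hypothetical; Keras manages its own weights internally.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# Random placeholder parameters with the shapes the two Dense layers would use.\n",
    "w1, b1 = rng.normal(size=(784, 128)) * 0.01, np.zeros(128)\n",
    "w2, b2 = rng.normal(size=(128, 10)) * 0.01, np.zeros(10)\n",
    "\n",
    "def forward(image):\n",
    "    # Compute 10 logits for a single 28 x 28 image.\n",
    "    x = image.reshape(-1)                 # Flatten: (28, 28) -> (784,)\n",
    "    hidden = np.maximum(0, x @ w1 + b1)   # Dense(128) followed by ReLU\n",
    "    return hidden @ w2 + b2               # Dense(10): raw scores, no activation\n",
    "\n",
    "logits = forward(rng.random((28, 28)))\n",
    "print(logits.shape)\n",
    "\n",
    "# Learnable parameters: weights plus biases of both Dense layers.\n",
    "num_params = w1.size + b1.size + w2.size + b2.size\n",
    "print(num_params)\n",
    "```\n",
    "\n",
    "The flattening step contributes no parameters; all of them come from the two dense layers."
   ]
  },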
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gut8A_7rCaW6"
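   },
   "source": [
    "The first layer in this network, `tf.keras.layers.Flatten`, transforms the format of the images from a two-dimensional array (of 28 by 28 pixels) to a one-dimensional array (of 28 * 28 = 784 pixels). Think of this layer as unstacking rows of pixels in the image and lining them up. This layer has no parameters to learn; it only reformats the data.\n",
    "\n",
    "After the pixels are flattened, the network consists of a sequence of two `tf.keras.layers.Dense` layers. These are densely connected, or fully connected, neural layers. The first `Dense` layer has 128 nodes (or neurons). The second (and last) layer returns a logits array with a length of 10. Each node holds a score for one of the 10 classes; the higher the score, the more strongly the model favors that class for the current image.\n",
    "\n",
    "### Compile the model\n",
    "\n",
    "Before the model is ready for training, it needs a few more settings. These are added during the model's <em>compile</em> step:\n",
    "\n",
    "- <em>Loss function</em> - Measures how far the model's predictions are from the true labels during training. You want to minimize this function to \"steer\" the model in the right direction.\n",
    "- <em>Optimizer</em> - Determines how the model is updated based on the data it sees and its loss function.\n",
    "- <em>Metrics</em> - Used to monitor the training and testing steps. The following example uses *accuracy*, the fraction of the images that are correctly classified."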
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:04.614187Z",
     "iopub.status.busy": "2023-11-08T00:32:04.613898Z",
     "iopub.status.idle": "2023-11-08T00:32:04.630551Z",
     "shell.execute_reply": "2023-11-08T00:32:04.629851Z"
    },
    "id": "Lhan11blCaW7"
   },
   "outputs": [],
   "source": [
    "model.compile(optimizer='adam',\n",
    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "              metrics=['accuracy'])"
   ]
  },
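  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition about the loss chosen above, sparse categorical cross-entropy on logits can be sketched in NumPy as the negative log of the softmax probability assigned to the true class, averaged over the batch. This is an illustrative reimplementation under that assumption, not the Keras code; `sparse_ce_from_logits` and the sample values are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def sparse_ce_from_logits(logits, labels):\n",
    "    # Mean cross-entropy for integer labels and unnormalized scores.\n",
    "    # Subtract the row maximum first so the exponentials cannot overflow.\n",
    "    shifted = logits - logits.max(axis=1, keepdims=True)\n",
    "    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))\n",
    "    # Pick the log-probability of the true class in every row.\n",
    "    true_log_probs = log_probs[np.arange(len(labels)), labels]\n",
    "    return -true_log_probs.mean()\n",
    "\n",
    "# Two made-up examples with three classes each.\n",
    "logits = np.array([[2.0, 0.5, -1.0],\n",
    "                   [0.0, 0.0, 0.0]])\n",
    "labels = np.array([0, 2])\n",
    "\n",
    "loss = sparse_ce_from_logits(logits, labels)\n",
    "print(loss)\n",
    "```\n",
    "\n",
    "A confident correct prediction (first row) contributes a small loss, while a uniform guess over three classes (second row) contributes log(3)."
   ]
  },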
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qKF6uW-BCaW-"
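   },
   "source": [
    "## Train the model\n",
    "\n",
    "Training the neural network model requires the following steps:\n",
    "\n",
    "1. Feed the training data to the model. In this example, the training data is in the `train_images` and `train_labels` arrays.\n",
    "2. The model learns to associate images and labels.\n",
    "3. You ask the model to make predictions about a test set (in this example, the `test_images` array).\n",
    "4. Verify that the predictions match the labels from the `test_labels` array.\n"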
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Z4P4zIV7E28Z"
   },
   "source": [
    "### Feed the model\n",
    "\n",
    "To start training, call the <code>model.fit</code> method, so called because it \"fits\" the model to the training data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:04.634748Z",
     "iopub.status.busy": "2023-11-08T00:32:04.634068Z",
     "iopub.status.idle": "2023-11-08T00:32:48.326687Z",
     "shell.execute_reply": "2023-11-08T00:32:48.325933Z"
    },
    "id": "xvwvpA64CaW_"
   },
   "outputs": [],
   "source": [
    "model.fit(train_images, train_labels, epochs=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "W3ZVOhugCaXA"
   },
   "source": [
    "As the model trains, the loss and accuracy metrics are displayed. This model reaches an accuracy of about 0.91 (or 91%) on the training data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wCpr6DGyE28h"
   },
   "source": [
    "### Evaluate accuracy\n",
    "\n",
    "Next, compare how the model performs on the test dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:48.330553Z",
     "iopub.status.busy": "2023-11-08T00:32:48.330234Z",
     "iopub.status.idle": "2023-11-08T00:32:49.146329Z",
     "shell.execute_reply": "2023-11-08T00:32:49.145520Z"
    },
    "id": "VflXLEeECaXC"
   },
   "outputs": [],
   "source": [
    "test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)\n",
    "\n",
    "print('\\nTest accuracy:', test_acc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yWfgsmVXCaXG"
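   },
   "source": [
    "It turns out that the accuracy on the test dataset is a little lower than the accuracy on the training dataset. This gap between training accuracy and test accuracy represents *overfitting*. Overfitting happens when a machine learning model performs worse on new, previously unseen inputs than it does on the training data. An overfitted model \"memorizes\" the noise and details in the training dataset to the point where this hurts its performance on new data. For more information, see the following:\n",
    "\n",
    "- [Demonstrate overfitting](https://tensorflow.google.cn/tutorials/keras/overfit_and_underfit#demonstrate_overfitting)\n",
    "- [Strategies to prevent overfitting](https://tensorflow.google.cn/tutorials/keras/overfit_and_underfit#strategies_to_prevent_overfitting)"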
   ]
  },
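  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The *accuracy* metric reported above is the fraction of images whose highest-scoring class matches the true label. A rough NumPy sketch with made-up scores follows; the arrays and the `accuracy` helper are hypothetical and not part of the tutorial's code.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def accuracy(scores, labels):\n",
    "    # Fraction of rows whose arg-max class equals the integer label.\n",
    "    predicted = np.argmax(scores, axis=1)\n",
    "    return float(np.mean(predicted == labels))\n",
    "\n",
    "# Made-up scores for four images and three classes.\n",
    "scores = np.array([[0.1, 2.0, 0.3],\n",
    "                   [1.5, 0.2, 0.1],\n",
    "                   [0.0, 0.4, 0.9],\n",
    "                   [0.8, 0.7, 0.1]])\n",
    "labels = np.array([1, 0, 2, 1])\n",
    "\n",
    "print(accuracy(scores, labels))\n",
    "```\n",
    "\n",
    "Computing this number once on training data and once on held-out data is one simple way to see the gap described above."
   ]
  },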
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "v-PyD1SYE28q"
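   },
   "source": [
    "### Make predictions\n",
    "\n",
    "With the model trained, you can use it to make predictions about some images. Attach a softmax layer to convert the model's linear outputs, the [logits](https://developers.google.com/machine-learning/glossary#logits), to probabilities, which are easier to interpret."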
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:49.150544Z",
     "iopub.status.busy": "2023-11-08T00:32:49.149962Z",
     "iopub.status.idle": "2023-11-08T00:32:49.173240Z",
     "shell.execute_reply": "2023-11-08T00:32:49.172536Z"
    },
    "id": "DnfNA0CrQLSD"
   },
   "outputs": [],
   "source": [
    "probability_model = tf.keras.Sequential([model, \n",
    "                                         tf.keras.layers.Softmax()])"
   ]
  },
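  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What the added Softmax layer computes can be sketched in a few lines of NumPy: each logit is exponentiated and divided by the row sum, so every row becomes a set of non-negative values that add up to 1. The `softmax` helper and the sample logits below are illustrative assumptions, not part of the Keras API.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def softmax(logits):\n",
    "    # Row-wise softmax; subtracting the maximum keeps the exponentials stable.\n",
    "    shifted = logits - logits.max(axis=-1, keepdims=True)\n",
    "    exps = np.exp(shifted)\n",
    "    return exps / exps.sum(axis=-1, keepdims=True)\n",
    "\n",
    "# Made-up logits for one image and four classes.\n",
    "logits = np.array([[1.0, 3.0, 0.5, -2.0]])\n",
    "probs = softmax(logits)\n",
    "\n",
    "print(probs.round(3))\n",
    "print(probs.sum(axis=-1))\n",
    "```\n",
    "\n",
    "Softmax preserves the ordering of the scores, so `np.argmax` picks the same class for the logits and for the probabilities."
   ]
  },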
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:49.176902Z",
     "iopub.status.busy": "2023-11-08T00:32:49.176626Z",
     "iopub.status.idle": "2023-11-08T00:32:49.985234Z",
     "shell.execute_reply": "2023-11-08T00:32:49.984263Z"
    },
    "id": "Gl91RPhdCaXI"
   },
   "outputs": [],
   "source": [
    "predictions = probability_model.predict(test_images)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "x9Kk1voUCaXJ"
   },
   "source": [
    "Here, the model has predicted the label for each image in the test set. Let's take a look at the first prediction:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:49.989969Z",
     "iopub.status.busy": "2023-11-08T00:32:49.989372Z",
     "iopub.status.idle": "2023-11-08T00:32:49.995194Z",
     "shell.execute_reply": "2023-11-08T00:32:49.994464Z"
    },
    "id": "3DmJEUinCaXK"
   },
   "outputs": [],
   "source": [
    "predictions[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-hw1hgeSCaXN"
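   },
   "source": [
    "A prediction is an array of 10 numbers. They represent the model's \"confidence\" that the image corresponds to each of the 10 different articles of clothing. You can see which label has the highest confidence value:"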
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:49.998786Z",
     "iopub.status.busy": "2023-11-08T00:32:49.998515Z",
     "iopub.status.idle": "2023-11-08T00:32:50.003623Z",
     "shell.execute_reply": "2023-11-08T00:32:50.002926Z"
    },
    "id": "qsqenuPnCaXO"
   },
   "outputs": [],
   "source": [
    "np.argmax(predictions[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "E51yS7iCCaXO"
   },
   "source": [
    "So, the model is most confident that this image is an ankle boot, or `class_names[9]`. Examining the test label shows that this classification is correct:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:50.006956Z",
     "iopub.status.busy": "2023-11-08T00:32:50.006684Z",
     "iopub.status.idle": "2023-11-08T00:32:50.011485Z",
     "shell.execute_reply": "2023-11-08T00:32:50.010792Z"
    },
    "id": "Sd7Pgsu6CaXP"
   },
   "outputs": [],
   "source": [
    "test_labels[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ygh2yYC972ne"
   },
   "source": [
    "To look at the full set of 10 class predictions, define two helper functions that graph them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:50.014800Z",
     "iopub.status.busy": "2023-11-08T00:32:50.014540Z",
     "iopub.status.idle": "2023-11-08T00:32:50.021354Z",
     "shell.execute_reply": "2023-11-08T00:32:50.020627Z"
    },
    "id": "DvYmmrpIy6Y1"
   },
   "outputs": [],
   "source": [
    "def plot_image(i, predictions_array, true_label, img):\n",
    "  true_label, img = true_label[i], img[i]\n",
    "  plt.grid(False)\n",
    "  plt.xticks([])\n",
    "  plt.yticks([])\n",
    "\n",
    "  plt.imshow(img, cmap=plt.cm.binary)\n",
    "\n",
    "  predicted_label = np.argmax(predictions_array)\n",
    "  if predicted_label == true_label:\n",
    "    color = 'blue'\n",
    "  else:\n",
    "    color = 'red'\n",
    "\n",
    "  plt.xlabel(\"{} {:2.0f}% ({})\".format(class_names[predicted_label],\n",
    "                                100*np.max(predictions_array),\n",
    "                                class_names[true_label]),\n",
    "                                color=color)\n",
    "\n",
    "def plot_value_array(i, predictions_array, true_label):\n",
    "  true_label = true_label[i]\n",
    "  plt.grid(False)\n",
    "  plt.xticks(range(10))\n",
    "  plt.yticks([])\n",
    "  thisplot = plt.bar(range(10), predictions_array, color=\"#777777\")\n",
    "  plt.ylim([0, 1])\n",
    "  predicted_label = np.argmax(predictions_array)\n",
    "\n",
    "  thisplot[predicted_label].set_color('red')\n",
    "  thisplot[true_label].set_color('blue')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Zh9yABaME29S"
   },
   "source": [
    "### Verify predictions\n",
    "\n",
    "With the model trained, you can use it to make predictions about some images."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "d4Ov9OFDMmOD"
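   },
   "source": [
    "Let's look at the 0th image, its prediction, and its prediction array. Correct prediction labels are blue and incorrect prediction labels are red. The number gives the percentage (out of 100) for the predicted label."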
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:50.025340Z",
     "iopub.status.busy": "2023-11-08T00:32:50.024642Z",
     "iopub.status.idle": "2023-11-08T00:32:50.147358Z",
     "shell.execute_reply": "2023-11-08T00:32:50.146601Z"
    },
    "id": "HV5jw-5HwSmO"
   },
   "outputs": [],
   "source": [
    "i = 0\n",
    "plt.figure(figsize=(6,3))\n",
    "plt.subplot(1,2,1)\n",
    "plot_image(i, predictions[i], test_labels, test_images)\n",
    "plt.subplot(1,2,2)\n",
    "plot_value_array(i, predictions[i],  test_labels)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:50.150818Z",
     "iopub.status.busy": "2023-11-08T00:32:50.150540Z",
     "iopub.status.idle": "2023-11-08T00:32:50.276300Z",
     "shell.execute_reply": "2023-11-08T00:32:50.275469Z"
    },
    "id": "Ko-uzOufSCSe"
   },
   "outputs": [],
   "source": [
    "i = 12\n",
    "plt.figure(figsize=(6,3))\n",
    "plt.subplot(1,2,1)\n",
    "plot_image(i, predictions[i], test_labels, test_images)\n",
    "plt.subplot(1,2,2)\n",
    "plot_value_array(i, predictions[i],  test_labels)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kgdvGD52CaXR"
   },
   "source": [
    "Let's plot several images with their predictions. Note that the model can be wrong even when it is very confident."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:50.280404Z",
     "iopub.status.busy": "2023-11-08T00:32:50.279613Z",
     "iopub.status.idle": "2023-11-08T00:32:52.201082Z",
     "shell.execute_reply": "2023-11-08T00:32:52.200200Z"
    },
    "id": "hQlnbqaw2Qu_"
   },
   "outputs": [],
   "source": [
    "# Plot the first X test images, their predicted labels, and the true labels.\n",
    "# Color correct predictions in blue and incorrect predictions in red.\n",
    "num_rows = 5\n",
    "num_cols = 3\n",
    "num_images = num_rows*num_cols\n",
    "plt.figure(figsize=(2*2*num_cols, 2*num_rows))\n",
    "for i in range(num_images):\n",
    "  plt.subplot(num_rows, 2*num_cols, 2*i+1)\n",
    "  plot_image(i, predictions[i], test_labels, test_images)\n",
    "  plt.subplot(num_rows, 2*num_cols, 2*i+2)\n",
    "  plot_value_array(i, predictions[i], test_labels)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "R32zteKHCaXT"
   },
   "source": [
    "## Use the trained model\n",
    "\n",
    "Finally, use the trained model to make a prediction about a single image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:52.205791Z",
     "iopub.status.busy": "2023-11-08T00:32:52.205507Z",
     "iopub.status.idle": "2023-11-08T00:32:52.209770Z",
     "shell.execute_reply": "2023-11-08T00:32:52.209071Z"
    },
    "id": "yRJ7JU7JCaXT"
   },
   "outputs": [],
   "source": [
    "# Grab an image from the test dataset.\n",
    "img = test_images[1]\n",
    "\n",
    "print(img.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vz3bVp21CaXV"
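   },
   "source": [
    "`tf.keras` models are optimized to make predictions on a *batch*, or collection, of examples at once. Accordingly, even though you're using a single image, you need to add it to a batch in which it is the only member:"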
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:52.213041Z",
     "iopub.status.busy": "2023-11-08T00:32:52.212781Z",
     "iopub.status.idle": "2023-11-08T00:32:52.216838Z",
     "shell.execute_reply": "2023-11-08T00:32:52.216138Z"
    },
    "id": "lDFh5yF_CaXW"
   },
   "outputs": [],
   "source": [
    "# Add the image to a batch where it's the only member.\n",
    "img = (np.expand_dims(img,0))\n",
    "\n",
    "print(img.shape)"
   ]
  },
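  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The batch dimension can be explored without a model. In this sketch, `single` is a made-up 28 x 28 array standing in for one test image; `np.expand_dims(single, 0)` and indexing with `np.newaxis` both add a leading axis of length 1, and indexing with `[0]` removes it again.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical stand-in for one preprocessed test image.\n",
    "single = np.zeros((28, 28))\n",
    "\n",
    "batch_a = np.expand_dims(single, 0)   # shape (1, 28, 28)\n",
    "batch_b = single[np.newaxis, ...]     # equivalent spelling\n",
    "\n",
    "print(batch_a.shape, batch_b.shape)\n",
    "\n",
    "# Taking element 0 of the batch recovers the original image shape.\n",
    "print(batch_a[0].shape)\n",
    "```\n",
    "\n",
    "A model fed a batch of one image returns a batch of one prediction, which is why the result is later indexed with `[0]`."
   ]
  },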
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EQ5wLTkcCaXY"
   },
   "source": [
    "Now predict the correct label for this image:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:52.220259Z",
     "iopub.status.busy": "2023-11-08T00:32:52.219645Z",
     "iopub.status.idle": "2023-11-08T00:32:52.316966Z",
     "shell.execute_reply": "2023-11-08T00:32:52.316168Z"
    },
    "id": "o_rzNSdrCaXY"
   },
   "outputs": [],
   "source": [
    "predictions_single = probability_model.predict(img)\n",
    "\n",
    "print(predictions_single)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:52.320682Z",
     "iopub.status.busy": "2023-11-08T00:32:52.319945Z",
     "iopub.status.idle": "2023-11-08T00:32:52.411625Z",
     "shell.execute_reply": "2023-11-08T00:32:52.410885Z"
    },
    "id": "6Ai-cpLjO-3A"
   },
   "outputs": [],
   "source": [
    "plot_value_array(1, predictions_single[0], test_labels)\n",
    "_ = plt.xticks(range(10), class_names, rotation=45)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cU1Y2OAMCaXb"
   },
   "source": [
    "`tf.keras.Model.predict` returns an array with one row of predictions for each image in the batch of data. Grab the predictions for our (only) image in the batch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-08T00:32:52.415511Z",
     "iopub.status.busy": "2023-11-08T00:32:52.414928Z",
     "iopub.status.idle": "2023-11-08T00:32:52.419907Z",
     "shell.execute_reply": "2023-11-08T00:32:52.419239Z"
    },
    "id": "2tRmdq_8CaXb"
   },
   "outputs": [],
   "source": [
    "np.argmax(predictions_single[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YFc2HbEVCaXd"
   },
   "source": [
    "And the model predicts a label as expected."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "classification.ipynb",
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
